# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TFDS Dataset for Wildfire Simulator."""
import dataclasses as dc
import functools as ft
from typing import Any, Iterator, Mapping, Sequence, Tuple, Union

from clu import deterministic_data
import jax
from jax import numpy as jnp
from jax import random
import tensorflow as tf
import tensorflow_datasets as tfds

from wildfire_perc_sim import configurations
from wildfire_perc_sim import utils
from wildfire_perc_sim import wildfire_simulator

# Type of a batch of wildfire data: a mapping from feature name to either a
# single array or a sequence of arrays.
WildfireBatchType = Mapping[str, Union[jnp.ndarray, Sequence[jnp.ndarray]]]

# Number of examples generated for the 'train', 'val' and 'test' splits.
NUM_TRAIN_SAMPLES = 5000
NUM_VAL_SAMPLES = 500
NUM_TEST_SAMPLES = 500


@dc.dataclass
class WildfireDatasetBuilderConfig(tfds.core.BuilderConfig):
  """WildfireDataset config.

  Attributes:
    seed: Seed of the JAX PRNG key from which every split is generated.
    maximum_rollout: Rollout the dynamical system for at most these many steps.
      The generator in this module records exactly this many steps after the
      first recorded state.
    maximum_start_delay: Start recording the states after at most
      `maximum_start_delay` steps. The generator in this module always skips
      exactly this many steps.
    field_dim: Side length of the square simulation field; examples have
      spatial shape `(field_dim, field_dim)`.
    neighborhood_size: Neighborhood size handed to the wildfire simulator when
      it builds its kernel.
    stochastic: Setting it to `False` means all samples are rolled out for
      `maximum_rollout` timesteps and start at timestep `maximum_start_delay`.
      The generator in this module does not read this flag yet, so both
      settings currently behave like `False`.
    wind: Whether a wind field is sampled; when `False` the wind field is all
      zeros.
    moisture: Whether a moisture field is sampled; when `False` the moisture
      field is all zeros.
    slope: Whether a sloped terrain is sampled; when `False` the terrain is
      all ones.
    dynamic_wind: Whether the wind changes during the rollout. The generator
      in this module raises `ValueError` when this is `True`.
  """
  seed: int = 0
  maximum_rollout: int = 15
  maximum_start_delay: int = 10
  field_dim: int = 64
  neighborhood_size: int = 5
  stochastic: bool = True
  wind: bool = True
  moisture: bool = True
  slope: bool = True
  dynamic_wind: bool = True


# Some common configurations
#
# Each name below is a factory (a `functools.partial` or the config class
# itself) that still expects the remaining `WildfireDatasetBuilderConfig`
# arguments, e.g. the config name.
#
# Naming scheme:
#   * `FixedTimeSeries*` factories set `stochastic=False`; `DynamicTimeSeries*`
#     factories keep the default `stochastic=True`.
#   * `Wind` disables moisture, slope and dynamic wind.
#   * `DynamicWind` disables moisture and slope.
#   * `DynamicWindSlope` disables only moisture.
#   * `Realistic` keeps every default of `WildfireDatasetBuilderConfig`.
DynamicTimeSeriesWindBuilderConfig = ft.partial(
    WildfireDatasetBuilderConfig,
    moisture=False,
    slope=False,
    dynamic_wind=False)
FixedTimeSeriesWindBuilderConfig = ft.partial(
    DynamicTimeSeriesWindBuilderConfig, stochastic=False)
DynamicTimeSeriesDynamicWindBuilderConfig = ft.partial(
    WildfireDatasetBuilderConfig, moisture=False, slope=False)
FixedTimeSeriesDynamicWindBuilderConfig = ft.partial(
    DynamicTimeSeriesDynamicWindBuilderConfig, stochastic=False)
DynamicTimeSeriesDynamicWindSlopeBuilderConfig = ft.partial(
    WildfireDatasetBuilderConfig, moisture=False)
FixedTimeSeriesDynamicWindSlopeBuilderConfig = ft.partial(
    DynamicTimeSeriesDynamicWindSlopeBuilderConfig, stochastic=False)
# The "realistic" dynamic variant is the unmodified config class.
DynamicTimeSeriesRealisticBuilderConfig = WildfireDatasetBuilderConfig
FixedTimeSeriesRealisticBuilderConfig = ft.partial(
    DynamicTimeSeriesRealisticBuilderConfig, stochastic=False)

# Every factory defined above: the four fixed variants first, then the four
# dynamic variants.
COMMON_BUILDER_CONFIG_PARTIALS = [
    # Fixed Time Series Models
    FixedTimeSeriesWindBuilderConfig,
    FixedTimeSeriesDynamicWindBuilderConfig,
    FixedTimeSeriesDynamicWindSlopeBuilderConfig,
    FixedTimeSeriesRealisticBuilderConfig,
    # Dynamic Time Series Models
    DynamicTimeSeriesWindBuilderConfig,
    DynamicTimeSeriesDynamicWindBuilderConfig,
    DynamicTimeSeriesDynamicWindSlopeBuilderConfig,
    DynamicTimeSeriesRealisticBuilderConfig,
]

# Human-readable dataset description handed to `tfds.core.DatasetInfo` by
# `WildfireDataset._info`.
# NOTE(review): "Rothermal" and "furthur" below look like misspellings of
# "Rothermel" and "further". They are left untouched because this string is
# emitted at runtime as part of the dataset metadata -- confirm before editing.
_DESCRIPTION = """
Wildfire Dataset contains time-series data generated by Rothermal Wildfire
Simulations. Dataset closely resembles the dataset proposed in
https://github.com/IhmeGroup/Wildfire-TPU but provides more information to
allow using the data for furthur simulations using the Rothermal model in
wildfire_simulator.

The dataset provides:

  - hidden_state: Hidden State at time step t_0.
  - hstate_sequence: Complete sequence of hidden states.
  - observation_sequence: Sequence of Observations.
  - kernel: Base Kernel for the Simulation.
  - alpha: Wind, Slope and Moisture coefficients (alphas).
  - burn_duration: Time before a cell gets completely burnt and the fire
      extinguishes.
  - nominal_ignition_heat: Heat needed to ignite each cell.

!!! warning
    Currently Dynamic Time Series BuilderConfigs are not implemented. If used
    they will yield the same results as their Fixed Time Series version.
"""


class WildfireDataset(tfds.core.GeneratorBasedBuilder):
  """TFDS builder whose examples are simulated wildfire rollouts.

  Nothing is downloaded: every example is produced by
  `_generate_wildfire_sequence` from a PRNG key derived from the seed of the
  selected builder config.
  """

  VERSION = tfds.core.Version('0.1.0')
  BUILDER_CONFIGS = [
      FixedTimeSeriesWindBuilderConfig('wind_fixed_ts'),
      FixedTimeSeriesDynamicWindBuilderConfig('dynamic_wind_fixed_ts'),
      FixedTimeSeriesDynamicWindSlopeBuilderConfig(
          'dynamic_wind_slope_fixed_ts'),
      FixedTimeSeriesRealisticBuilderConfig('realistic_fixed_ts'),
      DynamicTimeSeriesWindBuilderConfig('wind_dynamic_ts'),
      DynamicTimeSeriesDynamicWindBuilderConfig('dynamic_wind_dynamic_ts'),
      DynamicTimeSeriesDynamicWindSlopeBuilderConfig(
          'dynamic_wind_slope_dynamic_ts'),
      DynamicTimeSeriesRealisticBuilderConfig('realistic_dynamic_ts'),
  ]

  def _info(self):
    """Describes the features of one example.

    Returns:
      A `tfds.core.DatasetInfo` whose feature shapes follow from the builder
      config's `field_dim` and `neighborhood_size`.
    """
    cfg = self.builder_config
    grid = (cfg.field_dim, cfg.field_dim)
    # Channels of a hidden state: wind (2), terrain, moisture, density, heat,
    # fire, lit and burnt.
    num_hidden_channels = 9
    # The kernel shape is read off an actual kernel built for this
    # neighborhood size.
    kernel_shape = wildfire_simulator.get_kernel(cfg.neighborhood_size).shape

    def field_feature(num_channels):
      return tfds.features.Tensor(
          shape=grid + (num_channels,), dtype=tf.float32)

    def scalar_feature():
      return tfds.features.Tensor(shape=(1,), dtype=tf.float32)

    feature_spec = {
        'hidden_state': field_feature(num_hidden_channels),
        'hstate_sequence': tfds.features.Sequence(
            field_feature(num_hidden_channels)),
        'observation_sequence': tfds.features.Sequence(field_feature(2)),
        'kernel': tfds.features.Tensor(shape=kernel_shape, dtype=tf.float32),
        'alpha': tfds.features.Tensor(shape=(3,), dtype=tf.float32),
        'burn_duration': scalar_feature(),
        'nominal_ignition_heat': scalar_feature(),
    }
    return tfds.core.DatasetInfo(
        builder=self,
        description=_DESCRIPTION,
        features=tfds.features.FeaturesDict(feature_spec))

  def _split_generators(self, dl_manager):
    """Defines the 'train', 'val' and 'test' splits.

    Args:
      dl_manager: Unused; the data is simulated rather than downloaded.

    Returns:
      A dict mapping each split name to its example generator. Every split
      gets its own PRNG key, split off the key built from the config seed.
    """
    del dl_manager  # Nothing to download.
    split_sizes = (
        ('train', NUM_TRAIN_SAMPLES),
        ('val', NUM_VAL_SAMPLES),
        ('test', NUM_TEST_SAMPLES),
    )
    root_key = random.PRNGKey(self.builder_config.seed)
    split_keys = random.split(root_key, len(split_sizes))
    return {
        name: self._generate_examples(split_key, size, name)
        for (name, size), split_key in zip(split_sizes, split_keys)
    }

  def _generate_examples(self, prng, num_samples, split_name):
    """Yields `num_samples` simulated examples for one split.

    Args:
      prng: JAX PRNG key for this split; a fresh sub-key is split off for
        every example.
      num_samples: Number of examples to yield.
      split_name: Split name, used as the prefix of every example key.

    Yields:
      `(key, features)` pairs where `key` is `'<split_name>_<index>'`.
    """
    cfg = self.builder_config
    grid = (cfg.field_dim, cfg.field_dim)
    for index in range(num_samples):
      # Do later: rollout and start_delay should be controlled by the
      # stochastic parameter
      prng, sample_key = random.split(prng)
      example = _generate_wildfire_sequence(
          prng=sample_key,
          start_recording_at=cfg.maximum_start_delay,
          record_for=cfg.maximum_rollout,
          field_shape=grid,
          neighborhood_size=cfg.neighborhood_size,
          record_wind=cfg.wind,
          dynamic_wind=cfg.dynamic_wind,
          record_slope=cfg.slope,
          record_moisture=cfg.moisture)
      yield f'{split_name}_{index}', example


def _generate_wildfire_sequence(prng,
                                start_recording_at, record_for,
                                field_shape,
                                neighborhood_size, record_wind,
                                dynamic_wind, record_slope,
                                record_moisture):
  """Simulates one wildfire rollout and packages it as a dataset example.

  Terrain, wind, moisture, density and the ignition point are sampled from
  `prng`; the fire is then advanced with `wildfire_simulator.burn_step`. The
  states reached during the first `start_recording_at` steps are discarded and
  the following `record_for + 1` states are recorded.

  Args:
    prng: JAX PRNG key, consumed through successive `random.split` calls.
    start_recording_at: Number of burn steps taken before the first recorded
      state. Must be non-negative.
    record_for: Number of burn steps taken after the first recorded state, so
      each returned sequence has `record_for + 1` entries. Must be
      non-negative.
    field_shape: Two-element tuple with the spatial shape of the field.
    neighborhood_size: Neighborhood size forwarded to
      `wildfire_simulator.SimulatorProperties.create`.
    record_wind: If false, the wind field is all zeros instead of sampled.
    dynamic_wind: Not supported yet; must be false.
    record_slope: If false, the terrain is all ones instead of a sampled slope.
    record_moisture: If false, the moisture field is all zeros instead of
      sampled.

  Returns:
    A dict with the keys
      * 'hidden_state': first entry of 'hstate_sequence'.
      * 'hstate_sequence': list of recorded hidden states. Each one
        concatenates, along the last axis, wind, terrain, moisture, density,
        heat, fire, lit and burnt.
      * 'observation_sequence': list of recorded observations, each one the
        lit and burnt channels concatenated along the last axis.
      * 'kernel': base kernel of the simulator.
      * 'alpha': wind, slope and moisture coefficients, in that order.
      * 'burn_duration': burn duration used by the simulator.
      * 'nominal_ignition_heat': nominal ignition heat used by the simulator.

  Raises:
    ValueError: If `dynamic_wind` is true, or if `start_recording_at` or
      `record_for` is negative.
  """
  # Validate first so that unsupported requests fail before any sampling or
  # kernel generation is done.
  if dynamic_wind:
    raise ValueError('`dynamic_wind` is currently not supported.')
  if start_recording_at < 0 or record_for < 0:
    raise ValueError(
        '`start_recording_at` and `record_for` must be non-negative, got '
        f'{start_recording_at} and {record_for}.')

  if record_slope:
    prng, key1, key2 = random.split(prng, 3)
    terrain = configurations.terrain_slope(
        field_shape,
        random.uniform(key1, (1,), jnp.float32, minval=0, maxval=2 * jnp.pi),
        random.uniform(key2, (1,), jnp.float32, minval=0, maxval=2 * jnp.pi))
  else:
    terrain = jnp.ones(field_shape, jnp.float32)

  if record_wind:
    prng, key = random.split(prng)
    wind = configurations.wind_uniform(field_shape,
                                       random.normal(key, (2,)) * 3.0)
  else:
    wind = jnp.zeros(field_shape + (2,), jnp.float32)

  if record_moisture:
    prng, key = random.split(prng)
    moisture = configurations.moisture_random_normal(key, field_shape, 1.0,
                                                     0.25)
  else:
    moisture = jnp.zeros(field_shape, jnp.float32)

  nominal_ignition_heat = jnp.asarray([3.0], jnp.float32)
  wind_alpha = jnp.asarray([2.0], jnp.float32)
  moisture_alpha = jnp.asarray([1.0], jnp.float32)
  slope_alpha = jnp.asarray([0.7], jnp.float32)

  prng, key = random.split(prng)
  # NOTE(review): the same key seeds both density draws. This is kept as is
  # because splitting the key would change the data generated for a given
  # seed -- confirm whether the reuse is intended.
  density = configurations.density_random_normal(key, field_shape, 1.0, 0.25)
  density = density * configurations.density_bool(key, field_shape, 0.6)

  # The simulator inputs get a leading axis of size one via `expand_dims`; it
  # is stripped again with `[0, ...]` when states are recorded below.
  simprop = wildfire_simulator.SimulatorProperties.create(
      neighborhood_size=neighborhood_size,
      boundary_condition=utils.BoundaryCondition.INFINITE,
      nominal_ignition_heat=nominal_ignition_heat,
      burn_duration=jnp.array([5.0]))
  simparams = wildfire_simulator.SimulatorParameters(
      slope_alpha=jnp.expand_dims(slope_alpha, 0),
      wind_alpha=jnp.expand_dims(wind_alpha, 0),
      moisture_alpha=jnp.expand_dims(moisture_alpha, 0))
  fieldprop = wildfire_simulator.FieldProperties(
      moisture=jnp.expand_dims(moisture, 0),
      terrain=jnp.expand_dims(terrain, 0),
      wind=jnp.expand_dims(wind, 0),
      density=jnp.expand_dims(density, 0))

  dynamic_kernel = wildfire_simulator.parameterized_generate_dynamic_kernel(
      simprop, fieldprop, simparams)

  ignition_heat = wildfire_simulator.get_ignition_heat(
      simprop, fieldprop, simparams)

  prng, key = random.split(prng)
  lit_source = configurations.lit_from_pts(
      field_shape,
      configurations.location_random(key, ((1, field_shape[0] - 1),
                                           (1, field_shape[1] - 1))))
  lit_source = utils.set_border(lit_source, 0, 5)
  lit_source = jnp.expand_dims(lit_source, 0)

  bstate = wildfire_simulator.start_fire(lit_source, ignition_heat)

  # These channels do not change during the rollout, so slice them once.
  static_channels = (wind, terrain[Ellipsis, None], moisture[Ellipsis, None],
                     density[Ellipsis, None])

  hstate_sequence = []
  observation_sequence = []

  last_step = start_recording_at + record_for
  for step in range(last_step + 1):
    # Do Later: Stop recording once fire is extinguished. Need to do that
    # once we have models capable of handling variable length time series
    if step >= start_recording_at:
      lit = bstate.lit[0, Ellipsis, None]
      burnt = bstate.burnt[0, Ellipsis, None]
      hstate_sequence.append(
          jnp.concatenate(
              static_channels + (bstate.heat[0, Ellipsis, None],
                                 bstate.fire[0, Ellipsis, None], lit, burnt),
              axis=-1))
      observation_sequence.append(jnp.concatenate([lit, burnt], axis=-1))

    # The final recorded state needs no further simulation step.
    if step < last_step:
      bstate = wildfire_simulator.burn_step(bstate, simprop, dynamic_kernel,
                                            ignition_heat)

  return {
      # The first recorded state; always present thanks to the validation
      # above.
      'hidden_state': hstate_sequence[0],
      'hstate_sequence': hstate_sequence,
      'observation_sequence': observation_sequence,
      'kernel': simprop.base_kernel,
      'alpha': jnp.concatenate((wind_alpha, slope_alpha, moisture_alpha)),
      'burn_duration': simprop.burn_duration,
      'nominal_ignition_heat': nominal_ignition_heat,
  }


def preprocess_fn(features):
  """Converts a WildfireDataset feature structure for use with JAX.

  Every leaf is turned into a JAX array. In 'hidden_state', NaNs become 0 and
  positive/negative infinities become `utils.INF` / `-utils.INF`. The two
  sequence features are passed through
  `utils.restructure_distributed_sequence_data`.

  Args:
    features: Feature structure containing at least 'hidden_state',
      'observation_sequence' and 'hstate_sequence'.

  Returns:
    The converted feature structure.
  """
  # WARNING: All features here have to be compatible with JAX and TPUs. If your
  # dataset contains string features or variable length features you should
  # filter them here.
  converted = jax.tree_util.tree_map(jnp.asarray, features)

  finite_hidden_state = jnp.nan_to_num(
      jnp.asarray(converted['hidden_state']),
      nan=0,
      posinf=utils.INF,
      neginf=-utils.INF)
  converted['hidden_state'] = finite_hidden_state

  for sequence_name in ('observation_sequence', 'hstate_sequence'):
    converted[sequence_name] = utils.restructure_distributed_sequence_data(
        converted[sequence_name], True)

  return converted


@dc.dataclass
class DatasetConfig:
  """Options read by `create_datasets`.

  Attributes:
    name: Dataset name handed to `tfds.builder`.
    train_batch_size: Global batch size of the training dataset; must be
      divisible by `jax.device_count()`.
    eval_batch_size: Global batch size of the validation and test datasets;
      must be divisible by `jax.device_count()`.
    drop_remainder: Forwarded as `drop_remainder` to every dataset that
      `create_datasets` builds.
    data_dir: Forwarded as `data_dir` to `tfds.builder`.
  """
  name: str = 'wildfire_dataset/wind_fixed_ts'
  train_batch_size: int = 128
  eval_batch_size: int = 128
  drop_remainder: bool = True
  data_dir: str = ''


def create_datasets(config, data_rng):
  """Create datasets for training, validation and testing.

  For the same data_rng and config this will return the same datasets. The
  datasets only contain stateless operations.

  Args:
    config: Configuration to use. Reads `name`, `data_dir`,
      `train_batch_size`, `eval_batch_size` and `drop_remainder`.
    data_rng: PRNGKey for seeding operations in the training dataset.

  Returns:
    A 4-tuple `(info, train_ds, eval_ds, test_ds)`: the dataset info, the
    training dataset (built from the 'train' split), the evaluation dataset
    (built from the 'val' split) and the test dataset (built from the 'test'
    split).

  Raises:
    ValueError: If `config.train_batch_size` or `config.eval_batch_size` is
      not divisible by the number of devices.
  """
  # Compute batch size per device from global batch size.
  num_devices = jax.device_count()
  if config.train_batch_size % num_devices != 0:
    raise ValueError(f'Train Batch size ({config.train_batch_size}) must be '
                     f'divisible by the number of devices '
                     f'({num_devices}).')
  if config.eval_batch_size % num_devices != 0:
    raise ValueError(f'Eval Batch size ({config.eval_batch_size}) must be '
                     f'divisible by the number of devices '
                     f'({num_devices}).')
  per_device_train_batch_size = config.train_batch_size // num_devices
  per_device_eval_batch_size = config.eval_batch_size // num_devices

  dataset_builder = tfds.builder(config.name, data_dir=config.data_dir)
  dataset_builder.download_and_prepare()

  def split_for_this_process(split_name):
    """Returns the part of `split_name` assigned to the current JAX process."""
    return tfds.split_for_jax_process(
        split_name,
        process_index=jax.process_index(),
        process_count=jax.process_count())

  def single_pass_dataset(split_name):
    """Builds an unshuffled, one-epoch dataset with the eval batch size."""
    return deterministic_data.create_dataset(
        dataset_builder,
        split=split_for_this_process(split_name),
        num_epochs=1,
        shuffle=False,
        batch_dims=[jax.local_device_count(), per_device_eval_batch_size],
        preprocess_fn=None,
        prefetch_size=tf.data.AUTOTUNE,
        drop_remainder=config.drop_remainder)

  # Only the training dataset is shuffled, so only it receives `data_rng`.
  train_ds = deterministic_data.create_dataset(
      dataset_builder,
      split=split_for_this_process('train'),
      num_epochs=None,
      shuffle=True,
      batch_dims=[jax.local_device_count(), per_device_train_batch_size],
      preprocess_fn=None,
      prefetch_size=tf.data.AUTOTUNE,
      rng=data_rng,
      drop_remainder=config.drop_remainder)
  eval_ds = single_pass_dataset('val')
  test_ds = single_pass_dataset('test')

  return dataset_builder.info, train_ds, eval_ds, test_ds
